from typing import Dict

import flax.linen as nn
import gym
import numpy as np
from tqdm import trange


def original_evaluate(agent: nn.Module, env: gym.Env,
                      num_episodes: int) -> Dict[str, float]:
    """Run evaluation episodes and average the per-episode statistics.

    For each episode the environment is reset and stepped with actions from
    ``agent.sample_actions(observation, temperature=0.0)`` until the
    environment reports ``done``. The ``'return'``, ``'length'`` and
    ``'success'`` entries are then read from ``info['episode']`` of the
    final step.

    Args:
        agent: Object exposing ``sample_actions(observation, temperature)``.
        env: Environment using the 4-tuple ``step`` API. Its final ``info``
            must contain an ``'episode'`` mapping with the three keys above
            (presumably added by an episode-monitoring wrapper -- confirm
            against the caller).
        num_episodes: Number of episodes to roll out.

    Returns:
        Mapping from each statistic name to its mean over all episodes.
        With ``num_episodes == 0`` the means are taken over empty lists.
    """
    metric_names = ('return', 'length', 'success')
    per_episode = {name: [] for name in metric_names}

    for _ in trange(num_episodes, desc='evaluation', leave=False):
        observation = env.reset()
        done = False

        while not done:
            action = agent.sample_actions(observation, temperature=0.0)
            observation, _, done, info = env.step(action)

        # Only the info dict from the terminal step is consulted.
        episode_summary = info['episode']
        for name in metric_names:
            per_episode[name].append(episode_summary[name])

    return {name: np.mean(values) for name, values in per_episode.items()}

def evaluate(agent: nn.Module, env: gym.Env,
             num_episodes: int) -> Dict[str, float]:
    """Run evaluation episodes and report the mean undiscounted return.

    Unlike reading statistics from ``info['episode']``, this sums the raw
    rewards returned by ``env.step`` itself, so it does not require the
    environment to populate any episode summary in ``info``.

    Args:
        agent: Object exposing ``sample_actions(observation, temperature)``;
            it is always called with ``temperature=0.0``.
        env: Environment using the 4-tuple ``step`` API
            (``observation, reward, done, info``).
        num_episodes: Number of episodes to roll out.

    Returns:
        ``{'return': mean_episode_return}``. With ``num_episodes == 0`` the
        mean is taken over an empty list.
    """
    episode_returns = []

    for _ in trange(num_episodes, desc='evaluation', leave=False):
        observation, done = env.reset(), False
        episode_return = 0

        while not done:
            action = agent.sample_actions(observation, temperature=0.0)
            # The info dict is not needed here; rewards are summed directly.
            observation, reward, done, _ = env.step(action)
            episode_return += reward

        episode_returns.append(episode_return)

    return {'return': np.mean(episode_returns)}
